"""AWS credential management."""

import functools
import typing
from typing import cast

import boto3
from cki_lib import logger
from cki_lib import misc

from . import secrets
from . import token

LOGGER = logger.get_logger(__name__)


@functools.cache
def _current_account(profile_name: str | None = None) -> str:
    """Return the AWS account ID the given profile authenticates as.

    The account is looked up via STS ``get_caller_identity`` on a session
    created for ``profile_name`` (``None`` selects boto3's default session
    configuration). Successful lookups are memoized per profile name.
    """
    session = boto3.Session(profile_name=profile_name)
    identity = session.client("sts").get_caller_identity()
    return cast(str, identity["Account"])


@functools.cache
def _iam(profile_name: str | None = None) -> typing.Any:
    """Return an IAM client for the given profile, memoized per profile name.

    ``None`` selects boto3's default session configuration.

    The return type is ``typing.Any`` because ``boto3.client`` is a factory
    function rather than a class and therefore not usable as an annotation.
    """
    return boto3.Session(profile_name=profile_name).client("iam")


@token.register_token('aws_secret_access_key')
class AwsSecretAccessKey(token.Token):
    """AWS secret access key.

    Token type for IAM user access keys. Preparing and rotating are skipped
    unless every active token in the group carries a ``profile_name`` in its
    metadata. Creating, destroying and refreshing metadata go through the IAM
    and STS APIs via boto3.
    """

    # NOTE(review): consumed by token.Token; presumably the number of active
    # versions kept during cleanup -- confirm against the base class.
    clean_active_versions = 1

    def _has_metadata_profile(self) -> bool:
        """Check whether all tokens in the group have profile metadata.

        Returns:
            True if ``misc.get_nested_key`` yields a truthy ``meta/profile_name``
            for every value of ``self.active_token_data``. Also True when
            there are no values at all.
        """
        # The loop variable is deliberately not called "token" so that it does
        # not shadow the imported token module.
        return all(
            misc.get_nested_key(token_data, "meta/profile_name", "")
            for token_data in self.active_token_data.values()
        )

    def prepare(self) -> None:
        """Prepare the token group for rotation.

        Overrides to skip prepare if there's no "profile_name" in the metadata.
        """
        if not self._has_metadata_profile():
            LOGGER.info("Can't prepare %s. Missing 'profile_name' in metadata", self.token_group)
            return None

        return super().prepare()

    def rotate(self, token_version: str) -> None:
        """Rotate a token, and update the secrets file.

        Overrides to skip rotation if there's no "profile_name" in the metadata.

        Args:
            token_version: Passed through unchanged to the base implementation.
        """
        if not self._has_metadata_profile():
            LOGGER.info("Can't rotate %s. Missing 'profile_name' in metadata", self.token_group)
            return None

        return super().rotate(token_version)

    def _create_token(self, token_version: str, meta: dict[str, typing.Any]) -> str:
        """Create an AWS key.

        Args:
            token_version: Not used by this implementation.
            meta: Token metadata. ``user_name`` and ``account`` are required,
                ``profile_name`` is optional. Updated in place with
                ``access_key_id``, ``arn`` and ``created_at``.

        Returns:
            The secret access key of the newly created key.

        Raises:
            KeyError: If ``user_name`` or ``account`` is missing from ``meta``.
                Raised before any key is created.
        """
        iam = _iam(meta.get("profile_name"))
        # Read every required field before calling IAM: a KeyError after
        # create_access_key would drop the returned secret and leave an
        # orphaned access key behind.
        user_name = meta['user_name']
        account = meta['account']
        access_key = iam.create_access_key(UserName=user_name)["AccessKey"]
        meta.update({
            'access_key_id': access_key['AccessKeyId'],
            'arn': f'arn:aws:iam::{account}:user/{user_name}',
            'created_at': misc.ensure_tz_utc(access_key['CreateDate']).isoformat(),
        })
        return cast(str, access_key['SecretAccessKey'])

    def _destroy_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Destroy an AWS key.

        Args:
            token_version: Not used by this implementation.
            meta: Token metadata. ``user_name`` and ``access_key_id`` are
                required, ``profile_name`` is optional.
        """
        iam = _iam(meta.get("profile_name"))
        iam.delete_access_key(UserName=meta["user_name"], AccessKeyId=meta["access_key_id"])

    def _update_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Update the secret meta information about AWS tokens.

        Steps, each of which may end the update early:

        1. If ``meta`` has a truthy ``endpoint_url``, nothing is updated.
        2. If ``account`` or ``arn`` is missing, both are filled in from STS
           ``get_caller_identity``, called with the token's own credentials.
        3. If the account of the configured profile differs from the token's
           account, nothing further is updated.
        4. If ``user_name`` is missing, it is filled in from IAM
           ``get_access_key_last_used``.
        5. ``active`` and ``created_at`` are refreshed from the matching entry
           of IAM ``list_access_keys``.

        Args:
            token_version: Version whose secret is read when the token's own
                credentials are needed for step 2.
            meta: Token metadata, updated in place. ``access_key_id`` is
                required unless step 1 returns early.

        Raises:
            StopIteration: If ``list_access_keys`` returns no entry for
                ``access_key_id``.
        """
        token_name = self.full_token_name(token_version)

        if endpoint_url := meta.get('endpoint_url'):
            LOGGER.debug('Token meta update for %s not supported', endpoint_url)
            return
        if not meta.get('account') or not meta.get('arn'):
            # Ask STS who the key belongs to, authenticating with the key itself.
            response = (
                boto3.Session()
                .client(
                    "sts",
                    aws_access_key_id=meta["access_key_id"],
                    aws_secret_access_key=secrets.secret(token_name),
                )
                .get_caller_identity()
            )
            meta.update({
                'account': response['Account'],
                'arn': response['Arn'],
            })
        profile_name = meta.get("profile_name")
        # The IAM calls below run with the profile's credentials, so they are
        # only attempted when the profile is in the same account as the token.
        if (profile_account := _current_account(profile_name)) != meta["account"]:
            LOGGER.debug("Token account %s != %s", meta["account"], profile_account)
            return
        iam = _iam(profile_name)
        if not meta.get('user_name'):
            response = iam.get_access_key_last_used(AccessKeyId=meta["access_key_id"])
            meta['user_name'] = response['UserName']
        # NOTE(review): next() without a default raises StopIteration when the
        # key is not listed for the user; confirm whether callers want that.
        response = next(
            r
            for r in iam.list_access_keys(UserName=meta["user_name"])["AccessKeyMetadata"]
            if r["AccessKeyId"] == meta["access_key_id"]
        )
        meta.update({
            'active': response['Status'] == 'Active',
            'created_at': misc.ensure_tz_utc(response['CreateDate']).isoformat(),
        })
